AlgoWiki

Wavelet tree

A wavelet tree is a static data structure on an integer sequence that answers a rich family of range queries ($k$-th smallest, count of elements less than a value, range frequency, range median, and more), each in $O(\log \sigma)$ time, where $\sigma$ is the size of the value domain. Building the tree takes $O(n \log \sigma)$ time and space. In competitive programming $\sigma$ is almost always reduced to at most $n$ by coordinate compression, so both bounds become $O(n \log n)$.

The wavelet tree is closely related to the persistent segment tree, which answers the same queries with the same asymptotic complexity. In practice an array-based wavelet tree is usually faster (better cache behaviour, no pointer overhead), and many people find it simpler to implement correctly.

Description

Given an array $A[1 \ldots n]$ of integers drawn from the range $[\mathit{lo}, \mathit{hi}]$, a wavelet tree is a balanced binary tree over the value range. Each node covers a value sub-range $[\mathit{lo}, \mathit{hi}]$ and conceptually holds, in their original relative order, the elements whose values fall in that sub-range. At every internal node with $\mathit{mid} = \lfloor(\mathit{lo}+\mathit{hi})/2\rfloor$:

  • Elements with value $\le \mathit{mid}$ are routed to the left child
  • Elements with value $> \mathit{mid}$ are routed to the right child

The node stores a prefix-count array $\mathit{cnt}$, where $\mathit{cnt}[i]$ is the number of the first $i$ elements (in the node's local ordering) that were routed left. With these counts every query can navigate the tree without storing the values explicitly.

Construction

#include <algorithm> // stable_partition
#include <vector>
using namespace std;

struct WaveletTree {
    int lo, hi;
    WaveletTree *left = nullptr, *right = nullptr;
    vector<int> cnt; // cnt[i] = # of first i elements routed to left child

    // Build over A[from, to) with value range [lo, hi].
    // The array is reordered in place during construction.
    // Leaves and nodes that receive no elements keep cnt empty and have no children.
    void build(int *from, int *to, int lo, int hi) {
        this->lo = lo; this->hi = hi;
        if (from >= to || lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        cnt.reserve(to - from + 1);
        cnt.push_back(0);
        for (auto it = from; it != to; ++it)
            cnt.push_back(cnt.back() + (*it <= mid));
        // stable_partition keeps the relative order inside each half
        auto pivot = stable_partition(from, to,
                         [mid](int x){ return x <= mid; });
        left  = new WaveletTree(); left->build(from, pivot, lo, mid);
        right = new WaveletTree(); right->build(pivot, to, mid+1, hi);
    }

stable_partition physically splits the array so that left-routed elements come first; the left child then builds over that prefix, the right child over the remainder.

Queries

All queries take a 1-indexed range $[l, r]$. At each node, two values summarize the range:

$$\mathit{lb} = \mathit{cnt}[l-1], \quad \mathit{rb} = \mathit{cnt}[r]$$

$\mathit{rb} - \mathit{lb}$ elements of $[l,r]$ went left; the rest went right. In the left child the range maps to $[\mathit{lb}+1, \mathit{rb}]$; in the right child it maps to $[l - \mathit{lb},\; r - \mathit{rb}]$. A mapped range whose left end exceeds its right end is empty.
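The mapping above can be checked on a single node in isolation. The following is a minimal sketch, assuming a node that holds the sequence 3, 1, 4, 1, 5 with mid = 3; the helper names leftCounts, toLeft, and toRight are hypothetical and not part of the structure described in this article.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Prefix counts for one node: cnt[i] = how many of the first i elements are <= mid.
std::vector<int> leftCounts(const std::vector<int>& a, int mid) {
    std::vector<int> cnt(1, 0);
    for (int x : a) cnt.push_back(cnt.back() + (x <= mid ? 1 : 0));
    return cnt;
}

// Map the 1-indexed range [l, r] of a node to the matching range in its left child.
std::pair<int, int> toLeft(const std::vector<int>& cnt, int l, int r) {
    assert(1 <= l && l <= r && r < static_cast<int>(cnt.size()));
    return std::make_pair(cnt[l - 1] + 1, cnt[r]);
}

// Same for the right child. An empty result has first == second + 1.
std::pair<int, int> toRight(const std::vector<int>& cnt, int l, int r) {
    assert(1 <= l && l <= r && r < static_cast<int>(cnt.size()));
    return std::make_pair(l - cnt[l - 1], r - cnt[r]);
}
```

For the range [2, 4] (values 1, 4, 1) the left child holds 3, 1, 1 and receives positions [2, 3], while the right child holds 4, 5 and receives position [1, 1].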

$k$-th smallest (kth(l, r, k)): At each node count how many elements in $[l,r]$ went left. If $k$ does not exceed that count, recurse left; otherwise subtract it from $k$ and recurse right. At a leaf the value is determined.

    // k-th smallest element in A[l..r] (1-indexed, k is 1-based)
    int kth(int l, int r, int k) {
        if (lo == hi) return lo;
        int lb = cnt[l-1], rb = cnt[r];
        int inLeft = rb - lb;
        if (k <= inLeft) return left->kth(lb+1, rb, k);
        return right->kth(l-lb, r-rb, k-inLeft);
    }

Count less than (countLess(l, r, v)): returns $|\{i \in [l,r] : A[i] < v\}|$. If $v$ lies outside $[\mathit{lo}, \mathit{hi}]$ the answer is trivial: $0$ when $v \le \mathit{lo}$ and $r - l + 1$ when $v > \mathit{hi}$. Otherwise, if $v \le \mathit{mid}$ the right subtree can contribute nothing (all its values are $> \mathit{mid} \ge v$), so we recurse left only. If $v > \mathit{mid}$ all left elements are $< v$, so they all count, plus a recursive call into the right subtree.

    // # elements in A[l..r] with value < v
    int countLess(int l, int r, int v) {
        // An empty range may reach a node that was built without elements
        // (empty cnt, null children), so stop before indexing cnt.
        if (l > r || v <= lo) return 0;
        if (v > hi)  return r - l + 1;
        int mid = lo + (hi - lo) / 2;
        int lb = cnt[l-1], rb = cnt[r];
        if (v <= mid) return left->countLess(lb+1, rb, v);
        return (rb - lb) + right->countLess(l-lb, r-rb, v);
    }
};

Range frequency of value $v$ in $[l,r]$: countLess(l, r, v+1) - countLess(l, r, v).

Range median (the lower median when the range has even length): kth(l, r, (r - l + 2) / 2).
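Both derived queries are thin wrappers over kth and countLess. The sketch below is a compact, self-contained variant of the tree (it copies each half into a child instead of partitioning in place, and uses std::unique_ptr rather than raw pointers, both of which are choices made here for brevity); the names Wavelet, rangeFreq, and rangeMedian are illustrative assumptions.

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct Wavelet {
    int lo, hi;
    std::unique_ptr<Wavelet> left, right;
    std::vector<int> cnt;  // cnt[i] = number of the first i elements routed left

    Wavelet(const std::vector<int>& a, int lo_, int hi_) : lo(lo_), hi(hi_) {
        if (a.empty() || lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        std::vector<int> l, r;
        cnt.push_back(0);
        for (int x : a) {
            (x <= mid ? l : r).push_back(x);
            cnt.push_back(static_cast<int>(l.size()));
        }
        left.reset(new Wavelet(l, lo, mid));
        right.reset(new Wavelet(r, mid + 1, hi));
    }

    // k-th smallest in positions [l, r], 1-indexed, 1 <= k <= r - l + 1.
    int kth(int l, int r, int k) const {
        assert(1 <= l && l <= r && 1 <= k && k <= r - l + 1);
        if (lo == hi) return lo;
        int lb = cnt[l - 1], rb = cnt[r], inLeft = rb - lb;
        if (k <= inLeft) return left->kth(lb + 1, rb, k);
        return right->kth(l - lb, r - rb, k - inLeft);
    }

    // Number of elements in positions [l, r] with value < v.
    int countLess(int l, int r, int v) const {
        if (l > r || v <= lo) return 0;
        if (v > hi) return r - l + 1;
        int mid = lo + (hi - lo) / 2;
        int lb = cnt[l - 1], rb = cnt[r];
        if (v <= mid) return left->countLess(lb + 1, rb, v);
        return (rb - lb) + right->countLess(l - lb, r - rb, v);
    }
};

// Occurrences of v in positions [l, r].
int rangeFreq(const Wavelet& wt, int l, int r, int v) {
    return wt.countLess(l, r, v + 1) - wt.countLess(l, r, v);
}

// Lower median of positions [l, r].
int rangeMedian(const Wavelet& wt, int l, int r) {
    return wt.kth(l, r, (r - l + 2) / 2);
}
```

Note that rangeFreq evaluates v + 1, so it assumes v is below the largest representable int.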

Complexity

Each query descends at most one root-to-leaf path of the value-range tree, which has height $\lceil\log_2 \sigma\rceil$. Each step is $O(1)$, so every query is $O(\log \sigma)$.

Building performs one stable_partition pass per node, and each element takes part in exactly one pass per level, giving $O(n \log \sigma)$ total time. At any single level the $\mathit{cnt}$ arrays together hold $n$ counts plus one leading zero for each non-empty internal node, of which there are at most $n$ per level, so total space is also $O(n \log \sigma)$.

Applications

  • Range order statistics: finding the $k$-th smallest element in a subarray; counting how many elements fall in a value range $[a, b]$ (as countLess(l, r, b+1) - countLess(l, r, a)). See Order statistics tree for the single-element analogue.
  • Sliding window median: for a window of fixed size $k$ the median is kth(i, i+k-1, (k+1)/2), answered in $O(\log \sigma)$ per window position. This matches the $O(\log n)$ per update of a balanced BST approach asymptotically, usually with a better constant.
  • Range count-distinct: number of distinct values in $A[l \ldots r]$. Define $\mathit{prev}[i]$ as the last position before $i$ with the same value (0 if none). Build the wavelet tree on the $\mathit{prev}$ array, whose values lie in $[0, n-1]$. Then the count of distinct values in $[l, r]$ equals the count of positions $i \in [l,r]$ with $\mathit{prev}[i] < l$, which is exactly countLess(l, r, l) on the $\mathit{prev}$ wavelet tree, answered in $O(\log n)$.
  • Predecessor / successor in a range: the largest value $\le v$ in $A[l \ldots r]$ follows from combining countLess and kth. Compute $k =$ countLess(l, r, v+1), the number of elements $\le v$; if $k > 0$, return kth(l, r, k). Symmetrically, for the smallest value $\ge v$ compute $k =$ countLess(l, r, v) and, if $k < r - l + 1$, return kth(l, r, k+1).
  • Bitvector rank/select: at each tree level the $\mathit{cnt}$ array is a rank structure over the left/right routing bits; the wavelet tree generalises binary rank/select to multi-valued alphabets, which underpins many suffix array and compressed index applications.
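The count-distinct and predecessor items above translate almost directly into code. The following is a self-contained sketch under stated assumptions: Wavelet is a compact copy-based variant of the tree, and previousOccurrence, countDistinct, and predecessor are hypothetical helper names introduced only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <vector>

struct Wavelet {
    int lo, hi;
    std::unique_ptr<Wavelet> left, right;
    std::vector<int> cnt;  // cnt[i] = number of the first i elements routed left

    Wavelet(const std::vector<int>& a, int lo_, int hi_) : lo(lo_), hi(hi_) {
        if (a.empty() || lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        std::vector<int> l, r;
        cnt.push_back(0);
        for (int x : a) {
            (x <= mid ? l : r).push_back(x);
            cnt.push_back(static_cast<int>(l.size()));
        }
        left.reset(new Wavelet(l, lo, mid));
        right.reset(new Wavelet(r, mid + 1, hi));
    }
    int kth(int l, int r, int k) const {
        assert(1 <= l && l <= r && 1 <= k && k <= r - l + 1);
        if (lo == hi) return lo;
        int lb = cnt[l - 1], rb = cnt[r], inLeft = rb - lb;
        if (k <= inLeft) return left->kth(lb + 1, rb, k);
        return right->kth(l - lb, r - rb, k - inLeft);
    }
    int countLess(int l, int r, int v) const {
        if (l > r || v <= lo) return 0;
        if (v > hi) return r - l + 1;
        int mid = lo + (hi - lo) / 2;
        int lb = cnt[l - 1], rb = cnt[r];
        if (v <= mid) return left->countLess(lb + 1, rb, v);
        return (rb - lb) + right->countLess(l - lb, r - rb, v);
    }
};

// prev for 1-indexed position i is stored at index i - 1:
// the last earlier position holding the same value, or 0 if there is none.
std::vector<int> previousOccurrence(const std::vector<int>& a) {
    std::map<int, int> last;
    std::vector<int> prev;
    for (std::size_t i = 0; i < a.size(); ++i) {
        int& slot = last[a[i]];  // starts at 0 the first time a value is seen
        prev.push_back(slot);
        slot = static_cast<int>(i) + 1;
    }
    return prev;
}

// Distinct values in A[l..r]; prevTree must be built on previousOccurrence(A).
int countDistinct(const Wavelet& prevTree, int l, int r) {
    return prevTree.countLess(l, r, l);
}

// Largest value <= v in A[l..r]; returns false when no such element exists.
bool predecessor(const Wavelet& wt, int l, int r, int v, int& out) {
    int k = wt.countLess(l, r, v + 1);  // number of elements <= v
    if (k == 0) return false;
    out = wt.kth(l, r, k);
    return true;
}
```

The tree over prev is built with value range [0, n], which safely covers every previous position.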

Variants

Coordinate compression

When values can be up to $10^9$, map them to $[1, \sigma]$, where $\sigma \le n$ is the number of distinct values, before building the tree:

// Returns a wavelet tree over the values of A, compressed to [1, sigma],
// where sigma is the number of distinct values. Queries use 1-indexed positions.
// sorted_vals receives the sorted unique values so kth results can be decoded.
WaveletTree* buildCompressed(vector<int> &A, vector<int> &sorted_vals) {
    sorted_vals = A;
    sort(sorted_vals.begin(), sorted_vals.end());
    sorted_vals.erase(unique(sorted_vals.begin(), sorted_vals.end()),
                      sorted_vals.end());
    int sigma = sorted_vals.size();
    vector<int> C = A;
    for (int &x : C)
        x = lower_bound(sorted_vals.begin(), sorted_vals.end(), x)
            - sorted_vals.begin() + 1;
    WaveletTree *wt = new WaveletTree();
    wt->build(C.data(), C.data() + C.size(), 1, sigma);
    return wt;
}
// Decode: original value = sorted_vals[ wt->kth(l, r, k) - 1 ]
// countLess for original threshold x:
//   threshold = lower_bound(sorted_vals.begin(), sorted_vals.end(), x)
//               - sorted_vals.begin() + 1
//   wt->countLess(l, r, threshold)

Array-based (flat) implementation

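The pointer-based tree above allocates a node per split. A flat layout stores one $\mathit{cnt}$ array per level instead of one per node, which improves cache performance and removes allocation overhead. The version below splits on the bits of $A[i] - \mathit{lo}$ from the most significant bit down and partitions each whole level at once (the layout usually called a wavelet matrix), so a query that moves into the one-bit half must add that level's total zero count, cnt[d][n], to its positions: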

struct WaveletFlat {
    int n, lo, hi, levels;
    vector<vector<int>> cnt; // cnt[d][i] = # of first i elements at depth d whose bit is 0

    WaveletFlat(vector<int> A, int lo, int hi) : n(A.size()), lo(lo), hi(hi) {
        // levels = number of bits needed for the offsets in [0, hi - lo]
        levels = 1;
        while ((1 << levels) < hi - lo + 1) levels++;
        cnt.assign(levels, vector<int>(n + 1, 0));
        vector<int> nxt(n);
        for (int d = 0; d < levels; d++) {
            // Tracking a midpoint per node would need per-node bookkeeping;
            // splitting on one bit of (value - lo), from MSB to LSB, avoids it.
            int bit = levels - 1 - d;
            for (int i = 0; i < n; i++)
                cnt[d][i+1] = cnt[d][i] + !((A[i] - lo) >> bit & 1);
            // Stable partition of the whole level by `bit`: zeros first, then ones.
            int li = 0, ri = cnt[d][n];
            for (int i = 0; i < n; i++) {
                if (!((A[i] - lo) >> bit & 1)) nxt[li++] = A[i];
                else                           nxt[ri++] = A[i];
            }
            swap(A, nxt);
        }
    }
};

The flat implementation avoids new entirely and is typically faster in practice (roughly 2–3×, depending on input size and compiler).
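The flat structure is shown without a query, so here is one possible kth for a per-level layout. It is a self-contained sketch that assumes each level is partitioned as a whole (all zero-bit elements first, then all one-bit elements), which is the wavelet matrix convention; the name LevelWavelet is hypothetical, and hi - lo is assumed to fit in a non-negative int.

```cpp
#include <cassert>
#include <vector>

struct LevelWavelet {
    int n, lo, levels;
    std::vector<std::vector<int> > cnt;  // cnt[d][i] = zero bits among the first i elements at depth d

    LevelWavelet(std::vector<int> a, int lo_, int hi_)
        : n(static_cast<int>(a.size())), lo(lo_), levels(1) {
        while ((1LL << levels) < static_cast<long long>(hi_) - lo_ + 1) ++levels;
        cnt.assign(levels, std::vector<int>(n + 1, 0));
        std::vector<int> nxt(n);
        for (int d = 0; d < levels; ++d) {
            int bit = levels - 1 - d;
            for (int i = 0; i < n; ++i)
                cnt[d][i + 1] = cnt[d][i] + ((((a[i] - lo) >> bit) & 1) ? 0 : 1);
            int zi = 0, oi = cnt[d][n];  // write cursors for the zero and one halves
            for (int i = 0; i < n; ++i) {
                if (((a[i] - lo) >> bit) & 1) nxt[oi++] = a[i];
                else                          nxt[zi++] = a[i];
            }
            a.swap(nxt);
        }
    }

    // k-th smallest in positions [l, r], 1-indexed, 1 <= k <= r - l + 1.
    int kth(int l, int r, int k) const {
        assert(1 <= l && l <= r && r <= n && 1 <= k && k <= r - l + 1);
        int b = l - 1, e = r;  // half-open range [b, e) in the current level's order
        int offset = 0;        // bits of (answer - lo) decided so far
        for (int d = 0; d < levels; ++d) {
            int zb = cnt[d][b], ze = cnt[d][e], zeros = cnt[d][n];
            if (k <= ze - zb) {  // answer has a 0 in this bit
                b = zb;
                e = ze;
            } else {             // answer has a 1: skip the zeros of the range
                k -= ze - zb;
                offset |= 1 << (levels - 1 - d);
                b = zeros + (b - zb);
                e = zeros + (e - ze);
            }
        }
        return lo + offset;
    }
};
```

The only difference from the per-node mapping is the added zeros term when descending into the one-bit half, because that half starts after every zero-bit element of the level.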

Range sum augmentation

To answer "sum of the $k$ smallest elements in $A[l \ldots r]$", augment each node with a prefix-sum array stored alongside cnt, where entry $i$ is the sum of the values among the first $i$ elements that were routed left. A sumKSmallest(l, r, k) query mirrors kth: if $k$ does not exceed the number of range elements that went left, return the left subtree's answer; otherwise add the sum of all range elements that went left and recurse right for the remaining ones. At a leaf with value $v$ the remaining $k$ elements all equal $v$ and contribute $k \cdot v$. Each level does $O(1)$ extra work, so the $O(\log \sigma)$ per-query bound is unchanged. With coordinate compression the sums must be accumulated over the original values, not the compressed ranks.
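One way the augmentation could look is sketched below. It is a self-contained illustration on uncompressed values; the struct name SumWavelet and the member name sum are assumptions made for this example.

```cpp
#include <cassert>
#include <memory>
#include <vector>

struct SumWavelet {
    int lo, hi;
    std::unique_ptr<SumWavelet> left, right;
    std::vector<int> cnt;        // cnt[i] = number of the first i elements routed left
    std::vector<long long> sum;  // sum[i] = total value of those left-routed elements

    SumWavelet(const std::vector<int>& a, int lo_, int hi_) : lo(lo_), hi(hi_) {
        if (a.empty() || lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        std::vector<int> l, r;
        cnt.push_back(0);
        sum.push_back(0);
        for (int x : a) {
            bool goesLeft = x <= mid;
            (goesLeft ? l : r).push_back(x);
            cnt.push_back(cnt.back() + (goesLeft ? 1 : 0));
            sum.push_back(sum.back() + (goesLeft ? x : 0));
        }
        left.reset(new SumWavelet(l, lo, mid));
        right.reset(new SumWavelet(r, mid + 1, hi));
    }

    // Sum of the k smallest values in positions [l, r], 0 <= k <= r - l + 1.
    long long sumKSmallest(int l, int r, int k) const {
        assert(0 <= k && k <= r - l + 1);
        if (k == 0) return 0;
        if (lo == hi) return static_cast<long long>(k) * lo;  // k copies of the leaf value
        int lb = cnt[l - 1], rb = cnt[r], inLeft = rb - lb;
        if (k <= inLeft) return left->sumKSmallest(lb + 1, rb, k);
        // Take every left-routed element of the range, then the rest from the right.
        return (sum[r] - sum[l - 1]) + right->sumKSmallest(l - lb, r - rb, k - inLeft);
    }
};
```

The early return for k == 0 keeps the query from descending into a child that was built without elements.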

Problems

$k$-th order statistics

Solution sketch — K-th Number (MKTHNUM)

This is the canonical wavelet tree problem: given $A[1 \ldots n]$ and $m$ queries $(l, r, k)$, output the $k$-th smallest value in $A[l \ldots r]$.

Coordinate-compress the values to $[1, n]$, build the wavelet tree, and answer each query with kth(l, r, k) in $O(\log n)$, decoding the returned rank back to the original value. Total time is $O((n + m)\log n)$.
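The whole pipeline (compress, build, query, decode) fits in a short function. This is a sketch rather than a tuned submission: input and output handling is omitted, and Query, Wavelet, and answerKthQueries are names assumed for the example.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <vector>

struct Wavelet {
    int lo, hi;
    std::unique_ptr<Wavelet> left, right;
    std::vector<int> cnt;  // cnt[i] = number of the first i elements routed left

    Wavelet(const std::vector<int>& a, int lo_, int hi_) : lo(lo_), hi(hi_) {
        if (a.empty() || lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        std::vector<int> l, r;
        cnt.push_back(0);
        for (int x : a) {
            (x <= mid ? l : r).push_back(x);
            cnt.push_back(static_cast<int>(l.size()));
        }
        left.reset(new Wavelet(l, lo, mid));
        right.reset(new Wavelet(r, mid + 1, hi));
    }
    int kth(int l, int r, int k) const {
        assert(1 <= l && l <= r && 1 <= k && k <= r - l + 1);
        if (lo == hi) return lo;
        int lb = cnt[l - 1], rb = cnt[r], inLeft = rb - lb;
        if (k <= inLeft) return left->kth(lb + 1, rb, k);
        return right->kth(l - lb, r - rb, k - inLeft);
    }
};

struct Query { int l, r, k; };  // 1-indexed range and 1-based k

std::vector<int> answerKthQueries(const std::vector<int>& a,
                                  const std::vector<Query>& queries) {
    // Coordinate compression: ranks are 1-based indices into the sorted unique values.
    std::vector<int> vals(a);
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    std::vector<int> ranks;
    for (int x : a)
        ranks.push_back(static_cast<int>(
            std::lower_bound(vals.begin(), vals.end(), x) - vals.begin()) + 1);
    int sigma = static_cast<int>(vals.size());
    Wavelet wt(ranks, 1, sigma > 0 ? sigma : 1);
    std::vector<int> out;
    for (const Query& q : queries)
        out.push_back(vals[wt.kth(q.l, q.r, q.k) - 1]);  // decode rank -> value
    return out;
}
```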

Solution sketch — Sliding Median (CSES 1076)

For a sliding window of width $k$, the median is the $p$-th smallest element with $p = \lfloor (k+1)/2 \rfloor$. After building the wavelet tree on the full array, each window $[i, i+k-1]$ contributes one query kth(i, i+k-1, p). Each query is $O(\log n)$, giving $O(n \log n)$ total, the same bound as a balanced-BST sliding window and usually with a smaller constant.

Range counting

Solution sketch — K Query (KQUERY)

Each query asks: how many elements in $A[i \ldots j]$ are greater than $k$?

This is $(j - i + 1) - \mathtt{countLess}(i,\, j,\, k+1)$. After coordinate compression, call countLess(i, j, threshold), where threshold is the rank of the smallest stored value $\ge k+1$ (or one past the largest rank if there is none); each query runs in $O(\log n)$.
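A possible shape for this query on compressed values is sketched below. The threshold is obtained with std::upper_bound on $k$, which gives the same rank as a lower-bound search for $k+1$ without having to compute $k+1$; the names CountGreater and Wavelet are assumptions made for this illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <vector>

struct Wavelet {
    int lo, hi;
    std::unique_ptr<Wavelet> left, right;
    std::vector<int> cnt;  // cnt[i] = number of the first i elements routed left

    Wavelet(const std::vector<int>& a, int lo_, int hi_) : lo(lo_), hi(hi_) {
        if (a.empty() || lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        std::vector<int> l, r;
        cnt.push_back(0);
        for (int x : a) {
            (x <= mid ? l : r).push_back(x);
            cnt.push_back(static_cast<int>(l.size()));
        }
        left.reset(new Wavelet(l, lo, mid));
        right.reset(new Wavelet(r, mid + 1, hi));
    }
    int countLess(int l, int r, int v) const {
        if (l > r || v <= lo) return 0;
        if (v > hi) return r - l + 1;
        int mid = lo + (hi - lo) / 2;
        int lb = cnt[l - 1], rb = cnt[r];
        if (v <= mid) return left->countLess(lb + 1, rb, v);
        return (rb - lb) + right->countLess(l - lb, r - rb, v);
    }
};

struct CountGreater {
    std::vector<int> vals;        // sorted unique original values
    std::unique_ptr<Wavelet> wt;  // tree over 1-based ranks

    explicit CountGreater(const std::vector<int>& a) : vals(a) {
        std::sort(vals.begin(), vals.end());
        vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
        std::vector<int> ranks;
        for (int x : a)
            ranks.push_back(static_cast<int>(
                std::lower_bound(vals.begin(), vals.end(), x) - vals.begin()) + 1);
        int sigma = static_cast<int>(vals.size());
        wt.reset(new Wavelet(ranks, 1, sigma > 0 ? sigma : 1));
    }

    // Number of elements in A[i..j] (1-indexed) strictly greater than k.
    int query(int i, int j, int k) const {
        // Rank of the smallest stored value > k, or sigma + 1 if none exists.
        int threshold = static_cast<int>(
            std::upper_bound(vals.begin(), vals.end(), k) - vals.begin()) + 1;
        return (j - i + 1) - wt->countLess(i, j, threshold);
    }
};
```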

Count-distinct and range partitioning

Solution sketch — Till I Collapse (CF 786C)

For each $k = 1 \ldots n$, compute $f(k)$: the number of groups produced by greedily splitting $A$ into maximal-length contiguous segments, each with $\le k$ distinct values.

Build a wavelet tree on the previous-occurrence array $\mathit{prev}$, where $\mathit{prev}[i]$ is the last index $< i$ with $A[\mathit{prev}[i]] = A[i]$ (or $0$ if none). Then the number of distinct values in $A[l \ldots r]$ is countLess(l, r, l) on this tree, a single $O(\log n)$ call.

For a fixed $k$, simulate the greedy: starting at $l = 1$, binary search for the rightmost $r$ such that countDistinct(l, r) $\le k$, increment the group counter, and advance $l = r+1$. Each greedy step costs $O(\log^2 n)$, and there are at most $\lceil n/k \rceil$ steps for each $k$ (every group except the last contains at least $k$ elements), so the harmonic sum gives $O\left(n \log^2 n \sum_{k=1}^{n} 1/k\right) = O(n \log^3 n)$ overall. A different route, the standard divide-and-conquer on $k$, gives $O(n \sqrt{n} \log n)$.
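The greedy with binary search can be sketched as follows. This is an illustrative, unoptimised version of the $O(n \log^3 n)$ approach and may need constant-factor tuning to pass real time limits; groupsForEveryK and Wavelet are hypothetical names.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <vector>

struct Wavelet {
    int lo, hi;
    std::unique_ptr<Wavelet> left, right;
    std::vector<int> cnt;  // cnt[i] = number of the first i elements routed left

    Wavelet(const std::vector<int>& a, int lo_, int hi_) : lo(lo_), hi(hi_) {
        if (a.empty() || lo == hi) return;
        int mid = lo + (hi - lo) / 2;
        std::vector<int> l, r;
        cnt.push_back(0);
        for (int x : a) {
            (x <= mid ? l : r).push_back(x);
            cnt.push_back(static_cast<int>(l.size()));
        }
        left.reset(new Wavelet(l, lo, mid));
        right.reset(new Wavelet(r, mid + 1, hi));
    }
    int countLess(int l, int r, int v) const {
        if (l > r || v <= lo) return 0;
        if (v > hi) return r - l + 1;
        int mid = lo + (hi - lo) / 2;
        int lb = cnt[l - 1], rb = cnt[r];
        if (v <= mid) return left->countLess(lb + 1, rb, v);
        return (rb - lb) + right->countLess(l - lb, r - rb, v);
    }
};

// Returns f with f[k] = number of greedy groups for k = 1..n (f[0] is unused).
std::vector<int> groupsForEveryK(const std::vector<int>& a) {
    int n = static_cast<int>(a.size());
    std::map<int, int> last;  // value -> last 1-indexed position seen
    std::vector<int> prev;
    for (int i = 0; i < n; ++i) {
        int& slot = last[a[i]];
        prev.push_back(slot);
        slot = i + 1;
    }
    Wavelet prevTree(prev, 0, n);
    std::vector<int> f(n + 1, 0);
    for (int k = 1; k <= n; ++k) {
        int l = 1;
        while (l <= n) {
            // Invariant: [l, good] has <= k distinct values; bad is too far or past the end.
            int good = l, bad = n + 1;
            while (bad - good > 1) {
                int m = good + (bad - good) / 2;
                if (prevTree.countLess(l, m, l) <= k) good = m;
                else bad = m;
            }
            ++f[k];
            l = good + 1;
        }
    }
    return f;
}
```

The binary search is valid because, for a fixed left end, the number of distinct values never decreases as the right end moves right.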

See also

  • Persistent segment tree: same asymptotic complexity; the wavelet tree is the static alternative and typically uses less memory
  • Merge sort tree: stores a sorted list at each segment-tree node; answers the same queries in $O(\log^2 n)$ but is simpler to extend to updates
  • Segment tree: the underlying range-decomposition idea
  • Order statistics tree: handles point insertions/deletions but lacks arbitrary range queries
  • Coordinate compression: essential preprocessing step when values exceed $n$